from __future__ import absolute_import, division, print_function

import torch
import torch.nn as nn
import torchvision.models as models
from cbml_benchmark.modeling import registry
import torch.nn.functional as F
import numpy as np


@registry.BACKBONES.register('resnet50')
class ResNet50(nn.Module):
    """ResNet-50 trunk followed by a centroid-weighted residual aggregation.

    Wraps torchvision's ``resnet50`` (built with ``pretrained=True``), keeps
    every ``BatchNorm2d`` layer in eval mode, and replaces the usual
    pool + fc head with :meth:`ra3`, which aggregates the ``layer4`` feature
    map against a learnable set of centroids (``self.centroids3``).

    NOTE: a second aggregation branch (``ra2`` / ``centroids2``) previously
    existed only as commented-out code; only the ``index == 3`` configuration
    is implemented.
    """

    def __init__(self):
        super(ResNet50, self).__init__()
        self.model = models.resnet50(pretrained=True)

        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
                # Shadow train() on the instance so a later ``.train()`` on
                # any ancestor cannot switch the layer back to batch
                # statistics.  ``mode`` is optional so that a bare
                # ``module.train()`` no longer raises TypeError; like the
                # original override, the no-op returns None.
                module.train = lambda mode=True: None

        # Side length of the square output grid for each configuration; the
        # number of centroids is the square of it.  forward() uses index 3.
        self.num = [64, 32, 1, 8]
        self.num_clusters = [side * side for side in self.num]

        # Random initial value of the centroids, kept as a plain tensor.
        self.val3 = torch.randn(self.num_clusters[3], 2048)
        # Register the parameter exactly once.  (The previous
        # ``self.apply(...)`` invoked the initializer once per sub-module and
        # rebuilt the same parameter every time.)
        self._init_centroids(self)

    def _init_centroids(self, m=None):
        """Create the learnable centroid parameter from ``self.val3``.

        Args:
            m: Ignored; kept so the method can still be used as an
                ``nn.Module.apply`` callback.
        """
        # Same placement as before when CUDA is present, but fall back to the
        # CPU instead of crashing on hosts without a GPU.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        # clone() keeps ``val3`` independent of the trained parameter even
        # when no device transfer (and hence no copy) takes place.
        self.centroids3 = nn.Parameter(self.val3.clone().to(device))

    def ra3(self, x, index):
        """Aggregate residuals between local features and the centroids.

        Args:
            x: Feature map of shape ``(N, C1, H, W)``.  ``C1`` must equal the
                centroid width (2048) for the matrix product below.
            index: Position in ``self.num`` / ``self.num_clusters`` selecting
                the number of clusters and the output grid size.  Only
                ``index == 3`` matches the size of ``self.centroids3``.

        Returns:
            Tensor of shape ``(N, C1, self.num[index], self.num[index])``
            holding one L2-normalized aggregated residual per cluster.
        """
        N, C1, H, W = x.shape
        num_clusters = self.num_clusters[index]

        # (N, C1, H*W), each spatial descriptor scaled to unit L2 norm.
        x_flatten = F.normalize(x.view(N, C1, -1), p=2, dim=1)
        # Row-wise normalization is independent per centroid, so it is done
        # once here instead of once per loop iteration.
        centroids = F.normalize(self.centroids3, p=2, dim=1)

        # (N, 1, H*W, C1) @ (1, 1, C1, K) -> (N, 1, H*W, K) -> (N, K, H*W)
        sim = (torch.matmul(x_flatten.unsqueeze(0).permute(1, 0, 3, 2),
                            centroids.permute(1, 0).unsqueeze(0).unsqueeze(0)).permute(0, 1, 3, 2)
               / np.sqrt(num_clusters)).squeeze(1)
        # log(1 + exp(.)): softplus keeps the weights strictly positive.
        sim = torch.exp(sim)
        sim = torch.log(1 + sim)

        ra = torch.zeros([N, num_clusters, C1], dtype=x.dtype, layout=x.layout, device=x.device)
        # (N, 1, C1, H*W) view of the descriptors, shared by every iteration.
        descriptors = x_flatten.unsqueeze(0).permute(1, 0, 2, 3)
        num_positions = x_flatten.size(-1)
        # One cluster at a time keeps the intermediate at (N, 1, C1, H*W).
        for C in range(num_clusters):
            centroid = centroids[C:C + 1, :].expand(num_positions, -1, -1).permute(1, 2, 0).unsqueeze(0)
            residual = descriptors - centroid
            residual *= sim[:, C:C + 1, :].unsqueeze(2)
            ra[:, C:C + 1, :] = residual.sum(dim=-1) / C1
        ra = F.normalize(ra, p=2, dim=2)
        # (N, K, C1) -> (N, C1, K) -> (N, C1, num, num)
        ra = ra.permute(0, 2, 1).view(N, C1, self.num[index], self.num[index])
        return ra

    def forward(self, x):
        """Run the ResNet-50 trunk and the centroid aggregation.

        Args:
            x: Input batch fed to ``conv1`` of the wrapped network.

        Returns:
            Tuple ``(features, features, centroids, centroids)`` where
            ``features`` is the output of ``ra3(x, 3)`` and ``centroids`` is
            ``self.centroids3``; each is returned twice.
        """
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)

        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        x2 = self.ra3(x, 3).contiguous()

        return x2, x2, self.centroids3, self.centroids3

    def load_param(self, model_path):
        """Copy weights from a checkpoint into the wrapped ResNet.

        Args:
            model_path: Path (or file-like object) accepted by ``torch.load``
                that yields a mapping from parameter name to tensor.

        Raises:
            KeyError: If a checkpoint key (other than the skipped
                ``last_linear`` entries) is missing from the wrapped model.
        """
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        # map_location lets checkpoints saved on a GPU load on CPU-only
        # hosts; copy_() below moves the values to each parameter's device.
        param_dict = torch.load(model_path, map_location="cpu")
        # Entries share storage with the live parameters, so one snapshot is
        # enough for every in-place copy.
        own_state = self.model.state_dict()
        for name in param_dict:
            if 'last_linear' in name:
                continue
            own_state[name].copy_(param_dict[name])

@registry.BACKBONES.register('resnet18')
class ResNet18(nn.Module):
    """ResNet-18 feature extractor registered as the ``'resnet18'`` backbone.

    Wraps torchvision's ``resnet18`` (built with ``pretrained=True``), keeps
    every ``BatchNorm2d`` layer in eval mode, and returns the globally pooled
    features with the classification layer (``fc``) skipped.
    """

    def __init__(self):
        super(ResNet18, self).__init__()
        self.model = models.resnet18(pretrained=True)

        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
                # Shadow train() on the instance so a later ``.train()`` on
                # any ancestor cannot switch the layer back to batch
                # statistics.  ``mode`` is optional so that a bare
                # ``module.train()`` no longer raises TypeError; like the
                # original override, the no-op returns None.
                module.train = lambda mode=True: None

    def forward(self, x):
        """Return flattened, average-pooled features for the batch ``x``.

        Args:
            x: Input batch fed to ``conv1`` of the wrapped network.

        Returns:
            2-D tensor of shape ``(x.size(0), features)``.
        """
        net = self.model

        # Stem.
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        x = net.maxpool(x)

        # Residual stages.
        x = net.layer1(x)
        x = net.layer2(x)
        x = net.layer3(x)
        x = net.layer4(x)

        x = net.avgpool(x)
        # The classifier (net.fc) is intentionally not applied; the pooled
        # map is flattened to one feature vector per sample.
        x = x.view(x.size(0), -1)
        return x

    def load_param(self, model_path):
        """Copy weights from a checkpoint into the wrapped ResNet.

        Args:
            model_path: Path (or file-like object) accepted by ``torch.load``
                that yields a mapping from parameter name to tensor.

        Raises:
            KeyError: If a checkpoint key (other than the skipped
                ``last_linear`` entries) is missing from the wrapped model.
        """
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        # map_location lets checkpoints saved on a GPU load on CPU-only
        # hosts; copy_() below moves the values to each parameter's device.
        param_dict = torch.load(model_path, map_location="cpu")
        # Entries share storage with the live parameters, so one snapshot is
        # enough for every in-place copy.
        own_state = self.model.state_dict()
        for name in param_dict:
            if 'last_linear' in name:
                continue
            own_state[name].copy_(param_dict[name])

@registry.BACKBONES.register('resnet34')
class ResNet34(nn.Module):
    """ResNet-34 feature extractor registered as the ``'resnet34'`` backbone.

    Wraps torchvision's ``resnet34`` (built with ``pretrained=True``), keeps
    every ``BatchNorm2d`` layer in eval mode, and returns the globally pooled
    features with the classification layer (``fc``) skipped.
    """

    def __init__(self):
        super(ResNet34, self).__init__()
        self.model = models.resnet34(pretrained=True)

        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
                # Shadow train() on the instance so a later ``.train()`` on
                # any ancestor cannot switch the layer back to batch
                # statistics.  ``mode`` is optional so that a bare
                # ``module.train()`` no longer raises TypeError; like the
                # original override, the no-op returns None.
                module.train = lambda mode=True: None

    def forward(self, x):
        """Return flattened, average-pooled features for the batch ``x``.

        Args:
            x: Input batch fed to ``conv1`` of the wrapped network.

        Returns:
            2-D tensor of shape ``(x.size(0), features)``.
        """
        net = self.model

        # Stem.
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        x = net.maxpool(x)

        # Residual stages.
        x = net.layer1(x)
        x = net.layer2(x)
        x = net.layer3(x)
        x = net.layer4(x)

        x = net.avgpool(x)
        # The classifier (net.fc) is intentionally not applied; the pooled
        # map is flattened to one feature vector per sample.
        x = x.view(x.size(0), -1)
        return x

    def load_param(self, model_path):
        """Copy weights from a checkpoint into the wrapped ResNet.

        Args:
            model_path: Path (or file-like object) accepted by ``torch.load``
                that yields a mapping from parameter name to tensor.

        Raises:
            KeyError: If a checkpoint key (other than the skipped
                ``last_linear`` entries) is missing from the wrapped model.
        """
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        # map_location lets checkpoints saved on a GPU load on CPU-only
        # hosts; copy_() below moves the values to each parameter's device.
        param_dict = torch.load(model_path, map_location="cpu")
        # Entries share storage with the live parameters, so one snapshot is
        # enough for every in-place copy.
        own_state = self.model.state_dict()
        for name in param_dict:
            if 'last_linear' in name:
                continue
            own_state[name].copy_(param_dict[name])

@registry.BACKBONES.register('resnet101')
class ResNet101(nn.Module):
    """ResNet-101 feature extractor registered as the ``'resnet101'`` backbone.

    Wraps torchvision's ``resnet101`` (built with ``pretrained=True``), keeps
    every ``BatchNorm2d`` layer in eval mode, and returns the globally pooled
    features with the classification layer (``fc``) skipped.
    """

    def __init__(self):
        super(ResNet101, self).__init__()
        self.model = models.resnet101(pretrained=True)

        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
                # Shadow train() on the instance so a later ``.train()`` on
                # any ancestor cannot switch the layer back to batch
                # statistics.  ``mode`` is optional so that a bare
                # ``module.train()`` no longer raises TypeError; like the
                # original override, the no-op returns None.
                module.train = lambda mode=True: None

    def forward(self, x):
        """Return flattened, average-pooled features for the batch ``x``.

        Args:
            x: Input batch fed to ``conv1`` of the wrapped network.

        Returns:
            2-D tensor of shape ``(x.size(0), features)``.
        """
        net = self.model

        # Stem.
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        x = net.maxpool(x)

        # Residual stages.
        x = net.layer1(x)
        x = net.layer2(x)
        x = net.layer3(x)
        x = net.layer4(x)

        x = net.avgpool(x)
        # The classifier (net.fc) is intentionally not applied; the pooled
        # map is flattened to one feature vector per sample.
        x = x.view(x.size(0), -1)
        return x

    def load_param(self, model_path):
        """Copy weights from a checkpoint into the wrapped ResNet.

        Args:
            model_path: Path (or file-like object) accepted by ``torch.load``
                that yields a mapping from parameter name to tensor.

        Raises:
            KeyError: If a checkpoint key (other than the skipped
                ``last_linear`` entries) is missing from the wrapped model.
        """
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        # map_location lets checkpoints saved on a GPU load on CPU-only
        # hosts; copy_() below moves the values to each parameter's device.
        param_dict = torch.load(model_path, map_location="cpu")
        # Entries share storage with the live parameters, so one snapshot is
        # enough for every in-place copy.
        own_state = self.model.state_dict()
        for name in param_dict:
            if 'last_linear' in name:
                continue
            own_state[name].copy_(param_dict[name])

@registry.BACKBONES.register('resnet152')
class ResNet152(nn.Module):
    """ResNet-152 feature extractor registered as the ``'resnet152'`` backbone.

    Wraps torchvision's ``resnet152`` (built with ``pretrained=True``), keeps
    every ``BatchNorm2d`` layer in eval mode, and returns the globally pooled
    features with the classification layer (``fc``) skipped.
    """

    def __init__(self):
        super(ResNet152, self).__init__()
        self.model = models.resnet152(pretrained=True)

        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
                # Shadow train() on the instance so a later ``.train()`` on
                # any ancestor cannot switch the layer back to batch
                # statistics.  ``mode`` is optional so that a bare
                # ``module.train()`` no longer raises TypeError; like the
                # original override, the no-op returns None.
                module.train = lambda mode=True: None

    def forward(self, x):
        """Return flattened, average-pooled features for the batch ``x``.

        Args:
            x: Input batch fed to ``conv1`` of the wrapped network.

        Returns:
            2-D tensor of shape ``(x.size(0), features)``.
        """
        net = self.model

        # Stem.
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        x = net.maxpool(x)

        # Residual stages.
        x = net.layer1(x)
        x = net.layer2(x)
        x = net.layer3(x)
        x = net.layer4(x)

        x = net.avgpool(x)
        # The classifier (net.fc) is intentionally not applied; the pooled
        # map is flattened to one feature vector per sample.
        x = x.view(x.size(0), -1)
        return x

    def load_param(self, model_path):
        """Copy weights from a checkpoint into the wrapped ResNet.

        Args:
            model_path: Path (or file-like object) accepted by ``torch.load``
                that yields a mapping from parameter name to tensor.

        Raises:
            KeyError: If a checkpoint key (other than the skipped
                ``last_linear`` entries) is missing from the wrapped model.
        """
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        # map_location lets checkpoints saved on a GPU load on CPU-only
        # hosts; copy_() below moves the values to each parameter's device.
        param_dict = torch.load(model_path, map_location="cpu")
        # Entries share storage with the live parameters, so one snapshot is
        # enough for every in-place copy.
        own_state = self.model.state_dict()
        for name in param_dict:
            if 'last_linear' in name:
                continue
            own_state[name].copy_(param_dict[name])